package com.tomz.xatomic.xa;

import com.tomz.xatomic.Constant;
import com.tomz.xatomic.DynamicDataSourceContextHolder;
import org.apache.ibatis.cursor.Cursor;
import org.apache.ibatis.executor.BatchResult;
import org.apache.ibatis.session.*;
import org.mybatis.spring.SqlSessionTemplate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.InvocationTargetException;
import java.sql.Connection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

import static java.lang.reflect.Proxy.newProxyInstance;

/**
 * 分布式事务支持
 * @author ZHUFEIFEI
 */
public class XaDynamicSqlSessionTemplate extends SqlSessionTemplate {
    private final Logger log = LoggerFactory.getLogger(getClass());
    private Map<String, SqlSessionTemplate> templateMap = new ConcurrentHashMap<>();
    private SqlSession sqlSessionProxy;
    public XaDynamicSqlSessionTemplate(SqlSessionFactory sqlSessionFactory) {
        super(sqlSessionFactory);
        this.sqlSessionProxy = (SqlSession) newProxyInstance(
                SqlSessionFactory.class.getClassLoader(),
                new Class[] { SqlSession.class },
                (proxy, method, args) -> {
                        log.debug("xa-sqlSession-proxy => {}", method.getName());
                        return method.invoke(this.determineSqlSessionTemplate(), args);
                });
    }

    public void add(String name, SqlSessionTemplate sqlSessionTemplate) {
        this.templateMap.put(getKey(name), sqlSessionTemplate);
    }

    public Map<String, SqlSessionTemplate> sqlSessionTemplates() {
        return Collections.unmodifiableMap(this.templateMap);
    }

    @Override
    public <T> T selectOne(String statement) {
        return this.sqlSessionProxy.selectOne(statement);
    }

    @Override
    public <T> T selectOne(String statement, Object parameter) {
        return this.sqlSessionProxy.selectOne(statement, parameter);
    }

    @Override
    public <K, V> Map<K, V> selectMap(String statement, String mapKey) {
        return this.sqlSessionProxy.selectMap(statement, mapKey);
    }

    @Override
    public <K, V> Map<K, V> selectMap(String statement, Object parameter, String mapKey) {
        return this.sqlSessionProxy.selectMap(statement, parameter, mapKey);
    }

    @Override
    public <K, V> Map<K, V> selectMap(String statement, Object parameter, String mapKey, RowBounds rowBounds) {
        return this.sqlSessionProxy.selectMap(statement, parameter, mapKey, rowBounds);
    }

    @Override
    public <T> Cursor<T> selectCursor(String statement) {
        return this.sqlSessionProxy.selectCursor(statement);
    }

    @Override
    public <T> Cursor<T> selectCursor(String statement, Object parameter) {
        return this.sqlSessionProxy.selectCursor(statement, parameter);
    }

    @Override
    public <T> Cursor<T> selectCursor(String statement, Object parameter, RowBounds rowBounds) {
        return this.sqlSessionProxy.selectCursor(statement, parameter, rowBounds);
    }

    @Override
    public <E> List<E> selectList(String statement) {
        return this.sqlSessionProxy.selectList(statement);
    }

    @Override
    public <E> List<E> selectList(String statement, Object parameter) {
        return this.sqlSessionProxy.selectList(statement, parameter);
    }

    @Override
    public <E> List<E> selectList(String statement, Object parameter, RowBounds rowBounds) {
        return this.sqlSessionProxy.selectList(statement, parameter, rowBounds);
    }

    @Override
    public void select(String statement, ResultHandler handler) {
        this.sqlSessionProxy.select(statement, handler);
    }

    @Override
    public void select(String statement, Object parameter, ResultHandler handler) {
        this.sqlSessionProxy.select(statement, parameter, handler);
    }

    @Override
    public void select(String statement, Object parameter, RowBounds rowBounds, ResultHandler handler) {
        this.sqlSessionProxy.select(statement, parameter, rowBounds, handler);
    }

    @Override
    public int insert(String statement) {
        return this.sqlSessionProxy.insert(statement);
    }

    @Override
    public int insert(String statement, Object parameter) {
        return this.sqlSessionProxy.insert(statement, parameter);
    }

    @Override
    public int update(String statement) {
        return this.sqlSessionProxy.update(statement);
    }

    @Override
    public int update(String statement, Object parameter) {
        return this.sqlSessionProxy.update(statement, parameter);
    }

    @Override
    public int delete(String statement) {
        return this.sqlSessionProxy.delete(statement);
    }

    @Override
    public int delete(String statement, Object parameter) {
        return this.sqlSessionProxy.delete(statement, parameter);
    }

    /**
     * 此处是SqlSessionTemplate动态切换的关键
     * Mybatis初始化时注册Mapper，此时如果就切换了SqlSessionTemplate不能达到效果
     * MapperRegistry -> MapperProxy 中有获取sqlSession
     */
    @Override
    public <T> T getMapper(Class<T> type) {
//        return this.sqlSessionProxy.getMapper(type);
        return this.getConfiguration().getMapper(type, this);
    }

    @Override
    public void clearCache() {
        this.sqlSessionProxy.clearCache();
    }

    @Override
    public Configuration getConfiguration() {
        return this.sqlSessionProxy.getConfiguration();
    }

    @Override
    public Connection getConnection() {
        return this.sqlSessionProxy.getConnection();
    }

    @Override
    public List<BatchResult> flushStatements() {
        return this.sqlSessionProxy.flushStatements();
    }

    private SqlSession determineSqlSessionTemplate() {
        return this.templateMap.getOrDefault(getKey(DynamicDataSourceContextHolder.get()), this.defaultTemplate());
    }

    private SqlSessionTemplate defaultTemplate() {
        return this.templateMap.values().iterator().next();
    }

    private String getKey(String name) {
        if (Objects.isNull(name)) {
            return  Constant.DATA_SOURCE_DEFAULT_KEY;
        }
        return name;
//        return String.format("%s%s", PREFIX, name);
    }
}
